{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpuestimator.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "N6ZDpd9XzFeN"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "KUu4vOt5zI9d"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "edfbxDDh2AEs"
      },
      "source": [
        "## Predict Shakespeare with Cloud TPUs and TPUEstimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j4ShJE_voaYC"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This example uses TPUEstimator to build a language model and train it on a Cloud TPU. The language model predicts the next character of text given the text so far. Once trained, it can generate new snippets of text in a style similar to the training data.\n",
        "\n",
        "The model trains for 2000 steps and completes in approximately 5 minutes.\n",
        "\n",
        "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File \u003e View on GitHub**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dgAHfQtuhddd"
      },
      "source": [
        "## Learning objectives\n",
        "\n",
        "In this Colab, you will learn how to:\n",
        "*   Build a simple 2-layer, forward Long Short-Term Memory (LSTM) language model.\n",
        "*   Provide a _model function_ to train the model for TPUEstimator.\n",
        "*   Run the model forward and see how well it predicts the next character.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QrprJD-R-410"
      },
      "source": [
        "## Instructions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_I0RdnOSkNmi"
      },
      "source": [
        "\u003ch3\u003e  \u0026nbsp;\u0026nbsp;Train on TPU\u0026nbsp;\u0026nbsp; \u003ca href=\"https://cloud.google.com/tpu/\"\u003e\u003cimg valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"\u003e\u003c/a\u003e\u003c/h3\u003e\n",
        "\n",
        "   1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage.\n",
        "   1. On the main menu, click **Runtime** and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "   1. Click **Runtime** again and select **Run all**. Watch out: the authentication cell requires that you set `use_tpu` and supply a bucket name as input. You can also run the cells one at a time with Shift-ENTER.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PY5kGr0MQqqw"
      },
      "source": [
        "TPUs are located in Google Cloud. For optimal performance, they read data directly from Google Cloud Storage (GCS)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Lvo0t7XVIkWZ"
      },
      "source": [
        "## Data, model, and training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xzpUtDMqmA-x"
      },
      "source": [
        "For this exercise, you train the network using the combined works of Shakespeare to create a play-generating robot.\n",
        "\n",
        "\n",
        "The network outputs something Shakespeare-esque:\n",
        "\n",
        "___\n",
        "\u003cblockquote\u003e\n",
        "Loves that led me no dumbs lack her Berjoy's face with her to-day.   \n",
        "The spirits roar'd; which shames which within his powers  \n",
        "Which tied up remedies lending with occasion,  \n",
        "A loud and Lancaster, stabb'd in me   \n",
        "\tUpon my sword for ever: 'Agripo'er, his days let me free.  \n",
        "\tStop it of that word, be so: at Lear,  \n",
        "\tWhen I did profess the hour-stranger for my life,  \n",
        "\tWhen I did sink to be cried how for aught;  \n",
        "\tSome beds which seeks chaste senses prove burning;  \n",
        "But he perforces seen in her eyes so fast;  \n",
        "And _  \n",
        "\u003c/blockquote\u003e\n",
        "___\n",
        "\n",
        "To generate your own faux-Shakespeare, you begin with a data generator. The training data for the model is snippets from a text file; the _target_ snippet is offset by one character.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6FQiIHs_B2hS"
      },
      "source": [
        "### Authentication"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "myGkRWgYWD2g"
      },
      "outputs": [],
      "source": [
        "# !rm /content/adc.json"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "IcZkpa-e-Fas"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import pprint\n",
        "import re\n",
        "import time\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "use_tpu = True #@param {type:\"boolean\"}\n",
        "bucket = '' #@param {type:\"string\"}\n",
        "\n",
        "assert bucket, 'Must specify an existing GCS bucket name'\n",
        "print('Using bucket: {}'.format(bucket))\n",
        "\n",
        "if use_tpu:\n",
        "    assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'\n",
        "\n",
        "MODEL_DIR = 'gs://{}/{}'.format(bucket, time.strftime('tpuestimator-lstm/%Y-%m-%d-%H-%M-%S'))\n",
        "print('Using model dir: {}'.format(MODEL_DIR))\n",
        "\n",
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "\n",
        "if 'COLAB_TPU_ADDR' in os.environ:\n",
        "  TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])\n",
        "  \n",
        "  # Upload credentials to TPU.\n",
        "  with tf.Session(TF_MASTER) as sess:    \n",
        "    with open('/content/adc.json', 'r') as f:\n",
        "      auth_info = json.load(f)\n",
        "    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)\n",
        "  # Now credentials are set for all future sessions on this TPU.\n",
        "else:\n",
        "  TF_MASTER=''\n",
        "\n",
        "with tf.Session(TF_MASTER) as session:\n",
        "  pprint.pprint(session.list_devices())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Qew6Qt0-kGzO"
      },
      "source": [
        "### Training data\n",
        "\n",
        "You can use a `tf.data` pipeline to feed input data to your Estimator. The goal for this exercise is to have the model predict the next character, so you need to feed sequences from a supplied dataset where the source is offset from the target by one character.\n",
        "\n",
        "Note that the input pipeline uses `tf.contrib.data.enumerate_dataset()` and `tf.contrib.stateless.stateless_random_uniform` to generate deterministic uniform samples. Combined with setting `RunConfig.tf_random_seed`, this makes every run of the model reproducible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "E3V4V-Jxmuv3"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "!wget --show-progress --continue -O /content/shakespeare.txt http://www.gutenberg.org/files/100/100-0.txt\n",
        "\n",
        "SHAKESPEARE_TXT = '/content/shakespeare.txt'\n",
        "RANDOM_SEED = 42  # An arbitrary choice.\n",
        "\n",
        "def transform(txt):\n",
        "  return np.asarray([ord(c) for c in txt], dtype=np.int32)\n",
        "\n",
        "def input_fn(params):\n",
        "  \"\"\"Return a dataset of source and target sequences for training.\"\"\"\n",
        "  batch_size = params['batch_size']\n",
        "  print('Batch size: {}'.format(batch_size))\n",
        "  seq_len = params['seq_len']\n",
        "  with tf.gfile.GFile(params['source_file'], 'r') as f:\n",
        "    txt = f.read()\n",
        "    txt = ''.join([x for x in txt if ord(x) \u003c 128])\n",
        "    \n",
        "  tf.logging.info('Sample text: %s', txt[10000:10100])\n",
        "  source = tf.constant(transform(txt), dtype=tf.int32)\n",
        "  ds = tf.data.Dataset.from_tensors(source)\n",
        "  ds = ds.repeat()\n",
        "  ds = ds.apply(tf.contrib.data.enumerate_dataset())\n",
        "\n",
        "  def _select_seq(offset, src):\n",
        "    idx = tf.contrib.stateless.stateless_random_uniform(\n",
        "        [1], seed=[RANDOM_SEED, offset], dtype=tf.float32)[0]\n",
        "\n",
        "    max_start_offset = len(txt) - seq_len\n",
        "    idx = tf.cast(idx * max_start_offset, tf.int32)\n",
        "    print(idx)\n",
        "    \n",
        "    return {\n",
        "        'source': tf.reshape(src[idx:idx + seq_len], [seq_len]),\n",
        "        'target': tf.reshape(src[idx + 1:idx + seq_len + 1], [seq_len])\n",
        "    }\n",
        "\n",
        "  ds = ds.map(_select_seq)\n",
        "  ds = ds.batch(batch_size, drop_remainder=True)\n",
        "  ds = ds.prefetch(2)\n",
        "  return ds\n",
        "\n",
        "tf.reset_default_graph()\n",
        "tf.set_random_seed(0)\n",
        "with tf.Session() as session:\n",
        "  ds = input_fn({'batch_size': 1, 'seq_len': 10, 'source_file': SHAKESPEARE_TXT})\n",
        "  features = session.run(ds.make_one_shot_iterator().get_next())\n",
        "  print(features['source'])\n",
        "  print(features['target'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Bbb05dNynDrQ"
      },
      "source": [
        "### Build the model\n",
        "\n",
        "Now that you have some data, you can define your model. This example uses a simple 2-layer, forward Long Short-Term Memory (LSTM) language model.\n",
        "\n",
        "The difference between this model and a CPU/GPU model is that you must specify a static `shape` for the model's input. This allows TensorFlow to infer the shape of the model and to satisfy the XLA compiler's static shape requirement."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yLEM-fLJlEEt"
      },
      "outputs": [],
      "source": [
        "EMBEDDING_DIM = 1024\n",
        "\n",
        "# Construct a 2-layer LSTM\n",
        "def _lstm(inputs, batch_size, initial_state=None):\n",
        "  def _make_cell(layer_idx):\n",
        "    with tf.variable_scope('lstm/%d' % layer_idx,):\n",
        "      return tf.nn.rnn_cell.LSTMCell(\n",
        "          num_units=EMBEDDING_DIM,\n",
        "          state_is_tuple=True,\n",
        "          reuse=tf.AUTO_REUSE,\n",
        "      )\n",
        "\n",
        "  cell = tf.nn.rnn_cell.MultiRNNCell([\n",
        "      _make_cell(0), \n",
        "      _make_cell(1),\n",
        "  ])\n",
        "  if initial_state is None:\n",
        "    initial_state = cell.zero_state(batch_size, tf.float32)\n",
        "\n",
        "  outputs, final_state = tf.contrib.recurrent.functional_rnn(\n",
        "      cell, inputs, initial_state=initial_state, use_tpu=use_tpu)\n",
        "  return outputs, final_state\n",
        "\n",
        "\n",
        "def lstm_model(seq, initial_state=None):\n",
        "  with tf.variable_scope('lstm', \n",
        "                         initializer=tf.orthogonal_initializer,\n",
        "                         reuse=tf.AUTO_REUSE):\n",
        "    batch_size = seq.shape[0]\n",
        "    seq_len = seq.shape[1]\n",
        "\n",
        "    embedding_params = tf.get_variable(\n",
        "        'char_embedding', \n",
        "        initializer=tf.orthogonal_initializer(seed=0),\n",
        "        shape=(256, EMBEDDING_DIM), dtype=tf.float32)\n",
        "\n",
        "    embedding = tf.nn.embedding_lookup(embedding_params, seq)\n",
        "\n",
        "    lstm_output, lstm_state = _lstm(\n",
        "        embedding, batch_size, initial_state=initial_state)\n",
        "\n",
        "    # Apply a single dense layer to the output of our LSTM to predict\n",
        "    # our final characters.  This looks awkward as we have to flatten\n",
        "    # our input to 2 dimensions before applying the dense layer.\n",
        "    flattened = tf.reshape(lstm_output, [-1, EMBEDDING_DIM])\n",
        "    logits = tf.layers.dense(flattened, 256, name='logits',)\n",
        "    logits = tf.reshape(logits, [-1, seq_len, 256])\n",
        "    return logits, lstm_state"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j0ZYOd07qJws"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "Since this example uses TPUEstimator, you must provide a _model function_ to train the model. The model function specifies how to train, evaluate and run inference (predictions) on your model.\n",
        "\n",
        "Each part of the model function is covered in turn below. The first part is the training step.\n",
        "\n",
        "* Feed your source tensor to your LSTM model.\n",
        "* Compute the cross entropy loss to train it to better predict the target tensor.\n",
        "* Use the `AdamOptimizer` to optimize your network.\n",
        "* Wrap it with the `CrossShardOptimizer` which lets you use multiple TPU cores to train.  \n",
        "\n",
        "Finally, return a `TPUEstimatorSpec` to indicate how TPUEstimator should train your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1b5E8ZSUrCBk"
      },
      "outputs": [],
      "source": [
        "def train_fn(source, target):\n",
        "  logits, lstm_state = lstm_model(source)\n",
        "  batch_size = source.shape[0]\n",
        "  \n",
        "  loss = tf.reduce_mean(\n",
        "      tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "          labels=target, logits=logits))\n",
        "\n",
        "  optimizer = tf.train.AdamOptimizer(learning_rate=0.001)\n",
        "  if TF_MASTER:\n",
        "    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)\n",
        "  train_op = optimizer.minimize(loss, tf.train.get_global_step())\n",
        "  return tf.contrib.tpu.TPUEstimatorSpec(\n",
        "      mode=tf.estimator.ModeKeys.TRAIN,\n",
        "      loss=loss,\n",
        "      train_op=train_op,\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ro-Y2oG27l4r"
      },
      "source": [
        "### Evaluate the model\n",
        "\n",
        "The evaluation step is simpler: you run the model forward and check how well it predicts the next character. Returning a `TPUEstimatorSpec` in this section tells TPUEstimator how to evaluate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gycj1IPp63Fj"
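      },
      "outputs": [],
      "source": [
        "def eval_fn(source, target):\n",
        "  logits, _ = lstm_model(source)\n",
        "  # EVAL mode also needs a loss tensor; use the same cross entropy as train_fn.\n",
        "  loss = tf.reduce_mean(\n",
        "      tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "          labels=target, logits=logits))\n",
        "\n",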
        "  def metric_fn(labels, logits):\n",
        "    labels = tf.cast(labels, tf.int64)\n",
        "    return {\n",
        "        'recall@1': tf.metrics.recall_at_k(labels, logits, 1),\n",
        "        'recall@5': tf.metrics.recall_at_k(labels, logits, 5)\n",
        "    }\n",
        "\n",
        "  eval_metrics = (metric_fn, [target, logits])\n",
        "  return tf.contrib.tpu.TPUEstimatorSpec(\n",
        "      mode=tf.estimator.ModeKeys.EVAL, \n",
        "      loss=loss, \n",
        "      eval_metrics=eval_metrics)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rHenJneu78Sy"
      },
      "source": [
        "### Compute predictions\n",
        "\n",
        "The following step is not TPU-specific. It uses the input tensor as a _seed_ for the model, then uses a TensorFlow loop to sample characters from the model and return a result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-bRhg5Tx8PLr"
      },
      "outputs": [],
      "source": [
        "def predict_fn(source):\n",
        "  # Seed the model with our initial array\n",
        "  batch_size = source.shape[0]\n",
        "  logits, lstm_state = lstm_model(source)\n",
        "\n",
        "  def _body(i, state, preds):\n",
        "    \"\"\"Body of our prediction loop: predict the next character.\"\"\"\n",
        "    cur_preds = preds.read(i)\n",
        "    next_logits, next_state = lstm_model(\n",
        "        tf.cast(tf.expand_dims(cur_preds, -1), tf.int32), state)\n",
        "\n",
        "    # pull out the last (and only) prediction.\n",
        "    next_logits = next_logits[:, -1]\n",
        "    next_pred = tf.multinomial(\n",
        "        next_logits, num_samples=1, output_dtype=tf.int32)[:, 0]\n",
        "    preds = preds.write(i + 1, next_pred)\n",
        "    return (i + 1, next_state, preds)\n",
        "\n",
        "  def _cond(i, state, preds):\n",
        "    del state\n",
        "    del preds\n",
        "\n",
        "    # Loop until `predict_len - 1`: preds[0] is the initial state and we\n",
        "    # write to `i + 1` on each iteration.\n",
        "    return tf.less(i, predict_len - 1)\n",
        "\n",
        "  next_pred = tf.multinomial(\n",
        "      logits[:, -1], num_samples=1, output_dtype=tf.int32)[:, 0]\n",
        "\n",
        "  i = tf.constant(0, dtype=tf.int32)\n",
        "\n",
        "  predict_len = 500\n",
        "\n",
        "  # compute predictions as [seq_len, batch_size] to simplify indexing/updates\n",
        "  pred_var = tf.TensorArray(\n",
        "      dtype=tf.int32,\n",
        "      size=predict_len,\n",
        "      dynamic_size=False,\n",
        "      clear_after_read=False,\n",
        "      element_shape=(batch_size,),\n",
        "      name='prediction_accumulator',\n",
        "  )\n",
        "\n",
        "  pred_var = pred_var.write(0, next_pred)\n",
        "  _, _, final_predictions = tf.while_loop(_cond, _body,\n",
        "                                          [i, lstm_state, pred_var])\n",
        "\n",
        "  # reshape back to [batch_size, predict_len] and cast to int32\n",
        "  final_predictions = final_predictions.stack()\n",
        "  final_predictions = tf.transpose(final_predictions, [1, 0])\n",
        "  final_predictions = tf.reshape(final_predictions, (batch_size, predict_len))\n",
        "\n",
        "  return tf.contrib.tpu.TPUEstimatorSpec(\n",
        "      mode=tf.estimator.ModeKeys.PREDICT, \n",
        "      predictions={'predictions': final_predictions})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IwRTkF4l8e3M"
      },
      "source": [
        "### Build the model function\n",
        "\n",
        "To build the model function that TPUEstimator expects, combine the helper functions as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "D5e8TD4q8kq2"
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode, params):\n",
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n",
        "    return train_fn(features['source'], features['target'])\n",
        "  if mode == tf.estimator.ModeKeys.EVAL:\n",
        "    return eval_fn(features['source'], features['target'])\n",
        "  if mode == tf.estimator.ModeKeys.PREDICT:\n",
        "    return predict_fn(features['source'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VzBYDJI0_Tfm"
      },
      "source": [
        "### Run the model\n",
        "\n",
        "Use the following boilerplate to specify a TPU worker. After that, you are ready to train your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CS9no3m_rCf0"
      },
      "outputs": [],
      "source": [
        "def _make_estimator(num_shards, use_tpu=True):\n",
        "  config = tf.contrib.tpu.RunConfig(\n",
        "      tf_random_seed=RANDOM_SEED,\n",
        "      master=TF_MASTER,\n",
        "      model_dir=MODEL_DIR,\n",
        "      save_checkpoints_steps=5000,\n",
        "      tpu_config=tf.contrib.tpu.TPUConfig(\n",
        "          num_shards=num_shards, iterations_per_loop=100))\n",
        "\n",
        "  estimator = tf.contrib.tpu.TPUEstimator(\n",
        "      use_tpu=use_tpu,\n",
        "      model_fn=model_fn, config=config,\n",
        "      train_batch_size=1024,\n",
        "      eval_batch_size=1024,\n",
        "      predict_batch_size=128,\n",
        "      params={'seq_len': 100, 'source_file': SHAKESPEARE_TXT},\n",
        "  )\n",
        "  return estimator\n",
        "\n",
        "\n",
        "# Use all 8 cores for training\n",
        "estimator = _make_estimator(num_shards=8, use_tpu=use_tpu)\n",
        "estimator.train(\n",
        "    input_fn=input_fn,\n",
        "    max_steps=2000,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TCBtcpZkykSf"
      },
      "source": [
        "### Run predictions with the model\n",
        "\n",
        "Now that your model is trained, you can run predictions through it to generate faux-Shakespeare. Use the seed sentence to get your model started, then sample 500 characters from it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tU7M-EGGxR3E"
      },
      "outputs": [],
      "source": [
        "def _seed_input_fn(params):\n",
        "  del params\n",
        "  seed_txt = 'Looks it not like the king?'\n",
        "  seed = transform(seed_txt)\n",
        "  seed = tf.constant(seed.reshape([1, -1]), dtype=tf.int32)\n",
        "  # The input_fn passed to predict() must return a Dataset, not a Tensor.\n",
        "  return tf.data.Dataset.from_tensors({'source': seed})\n",
        "\n",
        "# Disable the TPU and use a single shard for prediction, since we're only\n",
        "# generating text from a single seed sequence.\n",
        "estimator = _make_estimator(num_shards=1, use_tpu=False)\n",
        "\n",
        "idx = next(estimator.predict(input_fn=_seed_input_fn))['predictions']\n",
        "print(''.join([chr(i) for i in idx]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2a5cGsSTEBQD"
      },
      "source": [
        "## What's next\n",
        "\n",
        "* Learn about [Cloud TPUs](https://cloud.google.com/tpu/docs), which Google designed and optimized to speed up and scale up ML training and inference workloads and to help ML engineers and researchers iterate more quickly.\n",
        "* Explore the range of [Cloud TPU tutorials and Colabs](https://cloud.google.com/tpu/docs/tutorials) to find other examples that can be used when implementing your ML project.\n",
        "\n",
        "On Google Cloud Platform, in addition to the GPUs and TPUs available on pre-configured [deep learning VMs](https://cloud.google.com/deep-learning-vm/), you will find [AutoML](https://cloud.google.com/automl/) *(beta)* for training custom models without writing code, and [Cloud ML Engine](https://cloud.google.com/ml-engine/docs/), which allows you to run parallel training jobs and hyperparameter tuning for your custom models on powerful distributed hardware.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [
        "N6ZDpd9XzFeN"
      ],
      "name": "Predict Shakespeare with Cloud TPUs and TPUEstimator",
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
